{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Omniglot 数据集单样本学习：MAML 算法\n",
    "\n",
    "### One Shot Learning on Omniglot: MAML Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path\n",
    "\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.autograd\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torch.nn import functional as F\n",
    "import torch.nn.init as init"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 配置 Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_size = 28 # side length in pixels that every image is resized to\n",
    "way_count = 5 # number of classes in each task\n",
    "shot_count = 1 # number of support-set samples per class in each task\n",
    "query_count = 15 # number of query-set samples per class in each task\n",
    "train_batch_size = 32 # number of tasks sampled for each training epoch\n",
    "test_batch_size = 200 # number of tasks sampled for each test\n",
    "train_step_count = 5 # gradient descent steps per task during training\n",
    "test_step_count = 10 # gradient descent steps per task during testing\n",
    "train_epoch_count = 801 # total number of training epochs (epochs 0 to 800)\n",
    "test_epoch_interval = 200 # a test is run every this many training epochs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据  Read Data\n",
    "\n",
    "请下载以下两个文件并解压到2个不同的文件夹。\n",
    "\n",
    "Please download the following two zip files and unzip them into two different folders.\n",
    "\n",
    "\n",
    "```\n",
    "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip\n",
    "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip\n",
    "```\n",
    "\n",
    "文件路径作为下列 `get_raw_data()` 函数的首个参数。\n",
    "\n",
    "Pass the path of each unzipped folder as the first argument of the `get_raw_data()` function below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 964 characters\n",
      "Found 659 characters\n"
     ]
    }
   ],
   "source": [
    "def get_raw_data(directory, image_size):\n",
    "    \"\"\"Load and preprocess all character images below a directory.\n",
    "\n",
    "    The directory tree is walked recursively and every folder that directly\n",
    "    contains `.png` files is treated as one character. Each image is\n",
    "    converted to grayscale, resized to `image_size` x `image_size` and\n",
    "    divided by 255.\n",
    "\n",
    "    Args:\n",
    "        directory: root folder of an unzipped Omniglot image set.\n",
    "        image_size: side length in pixels of the resized square images.\n",
    "\n",
    "    Returns:\n",
    "        A list of `(char_dir, images)` tuples, where `images` is an array of\n",
    "        shape `(image_count, 1, image_size, image_size)`.\n",
    "    \"\"\"\n",
    "    characters = []\n",
    "    for char_dir, dir_names, file_names in os.walk(directory):\n",
    "        # Sorting in place makes the traversal order of os.walk deterministic.\n",
    "        dir_names.sort()\n",
    "        # Keep image files only, so that a stray non-image file in a character\n",
    "        # folder is never handed to Image.open.\n",
    "        png_names = sorted(name for name in file_names if name.endswith('.png'))\n",
    "        if not png_names:\n",
    "            continue\n",
    "        images = []\n",
    "        for png_name in png_names:\n",
    "            file_path = os.path.join(char_dir, png_name)\n",
    "            # The context manager closes the file once the pixels are converted.\n",
    "            with Image.open(file_path) as image_file:\n",
    "                image = image_file.convert('L')\n",
    "                image = image.resize((image_size, image_size))\n",
    "            image = np.reshape(image, (1, image_size, image_size))\n",
    "            images.append(image / 255.)\n",
    "        characters.append((char_dir, np.array(images)))\n",
    "    print('Found {} characters'.format(len(characters)))\n",
    "    return characters\n",
    "\n",
    "\n",
    "# Build the paths with os.path.join so that they work on every operating\n",
    "# system and contain no backslash escape sequences. The folder names match\n",
    "# the names of the two zip files listed above.\n",
    "raw_data = {}\n",
    "raw_data['train'] = get_raw_data(\n",
    "        os.path.join('.', 'omniglot', 'images_background'),\n",
    "        image_size=image_size)\n",
    "raw_data['test'] = get_raw_data(\n",
    "        os.path.join('.', 'omniglot', 'images_evaluation'),\n",
    "        image_size=image_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "随机读取单样本分类任务的函数\n",
    "\n",
    "Functions that randomly sample one-shot classification tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_random_task(raw_data,\n",
    "        way_count, shot_count, query_count, permute=True):\n",
    "    \"\"\"Sample one random classification task.\n",
    "\n",
    "    `way_count` distinct entries of `raw_data` are drawn; for each of them\n",
    "    `shot_count` support images and `query_count` query images are drawn\n",
    "    without overlap. The label of an image is the position (0-based) of its\n",
    "    entry in the drawn order.\n",
    "\n",
    "    Args:\n",
    "        raw_data: sequence whose items hold an image array at index 1.\n",
    "        way_count: number of classes in the task.\n",
    "        shot_count: support images per class.\n",
    "        query_count: query images per class.\n",
    "        permute: if true, shuffle the support set and the query set.\n",
    "\n",
    "    Returns:\n",
    "        The tuple `(s_inputs, s_labels, q_inputs, q_labels)` of numpy arrays.\n",
    "    \"\"\"\n",
    "    per_class_count = shot_count + query_count\n",
    "    support_images, query_images = [], []\n",
    "    class_ids = np.random.choice(len(raw_data), way_count, replace=False)\n",
    "    for class_id in class_ids:\n",
    "        class_images = raw_data[class_id][1]\n",
    "        picked = np.random.choice(class_images.shape[0],\n",
    "                per_class_count, replace=False)\n",
    "        # The first shot_count picks form the support set, the rest the query set.\n",
    "        support_images.append(class_images[picked[:shot_count]])\n",
    "        query_images.append(class_images[picked[shot_count:]])\n",
    "\n",
    "    s_inputs = np.concatenate(support_images)\n",
    "    s_labels = np.array(\n",
    "            [label for label in range(way_count) for _ in range(shot_count)])\n",
    "    q_inputs = np.concatenate(query_images)\n",
    "    q_labels = np.array(\n",
    "            [label for label in range(way_count) for _ in range(query_count)])\n",
    "\n",
    "    if not permute:\n",
    "        return s_inputs, s_labels, q_inputs, q_labels\n",
    "\n",
    "    s_order = np.random.permutation(way_count * shot_count)\n",
    "    s_inputs, s_labels = s_inputs[s_order], s_labels[s_order]\n",
    "    q_order = np.random.permutation(way_count * query_count)\n",
    "    q_inputs, q_labels = q_inputs[q_order], q_labels[q_order]\n",
    "    return s_inputs, s_labels, q_inputs, q_labels\n",
    "\n",
    "\n",
    "def get_random_tasks(raw_data, batch_size,\n",
    "        way_count, shot_count, query_count, permute=True):\n",
    "    \"\"\"Sample a batch of independent random tasks.\n",
    "\n",
    "    Args:\n",
    "        raw_data: data passed on to `get_random_task`.\n",
    "        batch_size: number of tasks to sample.\n",
    "        way_count: number of classes in each task.\n",
    "        shot_count: support images per class.\n",
    "        query_count: query images per class.\n",
    "        permute: if true, shuffle the support and query set of every task.\n",
    "\n",
    "    Returns:\n",
    "        A list with `batch_size` results of `get_random_task`.\n",
    "    \"\"\"\n",
    "    # Forward `permute` instead of hard-coding True, so that the choice of the\n",
    "    # caller is honoured.\n",
    "    return [\n",
    "        get_random_task(raw_data,\n",
    "            way_count, shot_count, query_count, permute=permute)\n",
    "        for _ in range(batch_size)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  神经网络 Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Four convolution stages, given as (in_channels, kernel_size, stride). Every\n",
    "# stage is Conv2d -> ReLU -> BatchNorm2d with 64 output channels, no padding.\n",
    "_conv_stages = [(1, 3, 2), (64, 3, 2), (64, 3, 2), (64, 2, 1)]\n",
    "\n",
    "_layers = []\n",
    "for _in_channels, _kernel_size, _stride in _conv_stages:\n",
    "    _layers.append(nn.Conv2d(_in_channels, 64,\n",
    "            kernel_size=_kernel_size, stride=_stride, padding=0))\n",
    "    _layers.append(nn.ReLU())\n",
    "    _layers.append(nn.BatchNorm2d(num_features=64))\n",
    "\n",
    "# Classifier head: flatten the feature map (64 values for 28 x 28 inputs) and\n",
    "# map it to one logit per class.\n",
    "_layers.append(nn.Flatten())\n",
    "_layers.append(nn.Linear(64, way_count))\n",
    "\n",
    "net0 = nn.Sequential(*_layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "初始化网络 Initialize the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))\n",
       "  (1): ReLU()\n",
       "  (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))\n",
       "  (4): ReLU()\n",
       "  (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))\n",
       "  (7): ReLU()\n",
       "  (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (9): Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1))\n",
       "  (10): ReLU()\n",
       "  (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (12): Flatten()\n",
       "  (13): Linear(in_features=64, out_features=5, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def weights_init(m):\n",
    "    \"\"\"Initialize the parameters of a single layer in place.\n",
    "\n",
    "    Meant to be passed to `nn.Module.apply`. Conv2d, ConvTranspose2d and Linear\n",
    "    layers get Kaiming-normal weights and zero biases; BatchNorm2d layers get\n",
    "    weight 1 and bias 0. Modules of any other type are left untouched.\n",
    "\n",
    "    Args:\n",
    "        m: the module to initialize.\n",
    "    \"\"\"\n",
    "    layer_type = type(m)\n",
    "    if layer_type in (nn.Conv2d, nn.ConvTranspose2d, nn.Linear):\n",
    "        init.kaiming_normal_(m.weight)\n",
    "        # Layers created with bias=False have no bias tensor to fill.\n",
    "        if m.bias is not None:\n",
    "            init.constant_(m.bias, 0.)\n",
    "    elif layer_type is nn.BatchNorm2d:\n",
    "        # With affine=False both weight and bias are None.\n",
    "        if m.weight is not None:\n",
    "            init.constant_(m.weight, 1.)\n",
    "        if m.bias is not None:\n",
    "            init.constant_(m.bias, 0.)\n",
    "\n",
    "# Module.apply calls weights_init on net0 and on each of its submodules. It\n",
    "# returns net0 itself, which the notebook displays as the output of this cell.\n",
    "net0.apply(weights_init)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MAML 算法中需要构建梯度下降的计算图并用梯度下降后的网络变量作为网络权重，所以仅仅使用 `nn.Sequential` 构造得到的 `net0` 是不够的。下面的类扩展了 `nn.Sequential` 类，使得网络能根据外部变量进行计算。\n",
    "\n",
    "Since the MAML algorithm performs gradient descent steps and uses the resulting variables as the network weights, the `nn.Sequential` instance `net0` alone does not suffice. The following class wraps an `nn.Sequential` instance so that the forward pass can be computed with externally supplied parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConfigurableSequential(nn.Module):\n",
    "    \"\"\"Evaluate an `nn.Sequential` with externally supplied weights.\n",
    "\n",
    "    The wrapped network only defines the layer structure and the default\n",
    "    weights. `forward` re-implements every supported layer with the functional\n",
    "    API, so any list of tensors can be used in place of the stored parameters.\n",
    "    Supported layers: Conv2d, Linear, BatchNorm2d, Flatten and ReLU.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, net):\n",
    "        \"\"\"Store `net`, the `nn.Sequential` whose layers are evaluated.\"\"\"\n",
    "        super(ConfigurableSequential, self).__init__()\n",
    "        self.net = net\n",
    "\n",
    "    def forward(self, x, parameters=None, buffers=None, bn_training=True):\n",
    "        \"\"\"Compute the network output for `x`.\n",
    "\n",
    "        Args:\n",
    "            x: input tensor.\n",
    "            parameters: tensors used as weights, in the order of\n",
    "                `self.net.parameters()`; defaults to those parameters.\n",
    "            buffers: tensors used as BatchNorm statistics, in the order of\n",
    "                `self.net.buffers()`; defaults to those buffers.\n",
    "            bn_training: passed to `F.batch_norm` as `training`.\n",
    "\n",
    "        Returns:\n",
    "            The output of the last layer.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: if the network contains an unsupported layer.\n",
    "        \"\"\"\n",
    "        if parameters is None:\n",
    "            parameters = list(self.net.parameters())\n",
    "        if buffers is None:\n",
    "            buffers = list(self.net.buffers())\n",
    "\n",
    "        param_index, buffer_index = 0, 0\n",
    "        for m in self.net.modules():\n",
    "            layer_type = type(m)\n",
    "            if layer_type is nn.Sequential:\n",
    "                # Containers hold no computation; their children follow.\n",
    "                continue\n",
    "            if layer_type in (nn.Conv2d, nn.Linear):\n",
    "                weights = parameters[param_index]\n",
    "                param_index += 1\n",
    "                # A layer created with bias=False contributes no bias tensor\n",
    "                # to the parameter list.\n",
    "                bias = None\n",
    "                if m.bias is not None:\n",
    "                    bias = parameters[param_index]\n",
    "                    param_index += 1\n",
    "                if layer_type is nn.Linear:\n",
    "                    x = F.linear(x, weights, bias)\n",
    "                else:\n",
    "                    x = F.conv2d(x, weights, bias, stride=m.stride,\n",
    "                            padding=m.padding, dilation=m.dilation,\n",
    "                            groups=m.groups)\n",
    "            elif layer_type is nn.BatchNorm2d:\n",
    "                weights, bias = parameters[param_index], parameters[param_index + 1]\n",
    "                param_index += 2\n",
    "                running_mean, running_var = buffers[buffer_index], buffers[buffer_index + 1]\n",
    "                # Skip running_mean, running_var and num_batches_tracked.\n",
    "                buffer_index += 3\n",
    "                # Keep the functional default of 0.1 when the layer has no\n",
    "                # numeric momentum.\n",
    "                momentum = 0.1 if m.momentum is None else m.momentum\n",
    "                x = F.batch_norm(x, running_mean, running_var,\n",
    "                        weight=weights, bias=bias, training=bn_training,\n",
    "                        momentum=momentum, eps=m.eps)\n",
    "            elif layer_type is nn.Flatten:\n",
    "                x = x.flatten(m.start_dim, m.end_dim)\n",
    "            elif layer_type is nn.ReLU:\n",
    "                x = F.relu(x, inplace=m.inplace)\n",
    "            else:\n",
    "                raise NotImplementedError(\n",
    "                        'unsupported layer type: {}'.format(layer_type.__name__))\n",
    "        return x\n",
    "\n",
    "\n",
    "# Wrap net0 so that a forward pass can use externally supplied weights.\n",
    "# net.parameters() still yields the parameters of net0.\n",
    "net = ConfigurableSequential(net0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MAML 算法\n",
    "MAML Algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss used for the per-task gradient descent steps and for the meta-update.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "# Adam with its default settings updates the initial weights shared by all tasks.\n",
    "meta_optimizer = optim.Adam(net.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run(net, tasks, step_count, criterion,\n",
    "        update_lr=0.4, meta_optimizer=None, first_order=True):\n",
    "    \"\"\"Adapt `net` to each task and optionally perform one meta-update.\n",
    "\n",
    "    For every task the weights start from `net.parameters()` and are adapted\n",
    "    with `step_count` gradient descent steps on the support set. The query set\n",
    "    is evaluated before the first step and after every step.\n",
    "\n",
    "    Args:\n",
    "        net: module that is called as `net(inputs, parameters)`.\n",
    "        tasks: non-empty sequence of `(s_inputs, s_labels, q_inputs, q_labels)`\n",
    "            tuples of numpy arrays.\n",
    "        step_count: number of gradient descent steps per task.\n",
    "        criterion: loss function called as `criterion(logits, labels)`.\n",
    "        update_lr: learning rate of the per-task gradient descent.\n",
    "        meta_optimizer: if given, the mean query loss of the adapted weights\n",
    "            over all tasks is back-propagated and the optimizer takes one step.\n",
    "        first_order: if true (default), the per-task gradients are treated as\n",
    "            constants by the meta-update. If false, they are computed with\n",
    "            `create_graph=True` so that the meta-update also differentiates\n",
    "            through the gradient descent steps. Only used with `meta_optimizer`.\n",
    "\n",
    "    Returns:\n",
    "        `(q_accs, q_losses)`: two arrays of length `step_count + 1`. Entry `k`\n",
    "        holds the query accuracy over all tasks and the mean query loss per\n",
    "        task after `k` gradient descent steps.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `tasks` is empty.\n",
    "    \"\"\"\n",
    "    if len(tasks) == 0:\n",
    "        raise ValueError('tasks must contain at least one task')\n",
    "    training = meta_optimizer is not None\n",
    "\n",
    "    q_corrects = np.zeros(shape=(step_count + 1,))\n",
    "    q_totals = np.zeros(shape=(step_count + 1,))\n",
    "    q_losses = np.zeros(shape=(step_count + 1,))\n",
    "    maml_losses = []\n",
    "\n",
    "    for s_inputs, s_labels, q_inputs, q_labels in tasks:\n",
    "        s_inputs = torch.from_numpy(s_inputs).float()\n",
    "        s_labels = torch.from_numpy(s_labels).long()\n",
    "        q_inputs = torch.from_numpy(q_inputs).float()\n",
    "        q_labels = torch.from_numpy(q_labels).long()\n",
    "\n",
    "        # Every task starts from the current weights of the network.\n",
    "        variables = list(net.parameters())\n",
    "        for k in range(step_count + 1):\n",
    "            if k:\n",
    "                # One gradient descent step on the support set.\n",
    "                s_logits = net(s_inputs, variables)\n",
    "                s_loss = criterion(s_logits, s_labels)\n",
    "                grads = torch.autograd.grad(s_loss, variables,\n",
    "                        create_graph=training and not first_order)\n",
    "                variables = [v - update_lr * g for v, g in zip(variables, grads)]\n",
    "\n",
    "            with torch.no_grad():\n",
    "                # Query loss and accuracy after k steps (statistics only).\n",
    "                q_logits = net(q_inputs, variables)\n",
    "                q_losses[k] += criterion(q_logits, q_labels).item()\n",
    "                # The argmax of the logits equals the argmax of their softmax.\n",
    "                q_preds = q_logits.argmax(dim=1)\n",
    "                q_corrects[k] += torch.eq(q_preds, q_labels).sum().item()\n",
    "                q_totals[k] += q_labels.shape[0]\n",
    "\n",
    "        if training:\n",
    "            # Query loss of the fully adapted weights, this time with gradients.\n",
    "            q_logits = net(q_inputs, variables)\n",
    "            maml_losses.append(criterion(q_logits, q_labels))\n",
    "\n",
    "    if training:\n",
    "        # Meta-update on the mean loss over all tasks.\n",
    "        maml_loss = torch.mean(torch.stack(maml_losses))\n",
    "        meta_optimizer.zero_grad()\n",
    "        maml_loss.backward()\n",
    "        meta_optimizer.step()\n",
    "\n",
    "    q_accs = q_corrects / q_totals\n",
    "    q_losses = q_losses / len(tasks)\n",
    "    return q_accs, q_losses"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练与测试 Train & test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0 : train accuracy [0.18958333 0.30875    0.39791667 0.41541667 0.41458333 0.41541667]\n",
      "epoch 0 : test accuracy [0.195      0.3225     0.40625    0.40958333 0.40916667 0.41\n",
      " 0.41041667 0.41041667 0.41041667 0.41166667 0.41166667]\n",
      "epoch 50 : train accuracy [0.20833333 0.55541667 0.59083333 0.6025     0.60416667 0.6075    ]\n",
      "epoch 100 : train accuracy [0.1825     0.69291667 0.7275     0.72791667 0.72875    0.73      ]\n",
      "epoch 150 : train accuracy [0.17416667 0.755      0.77291667 0.78166667 0.78208333 0.78291667]\n",
      "epoch 200 : train accuracy [0.19958333 0.79708333 0.82625    0.82791667 0.82916667 0.83041667]\n",
      "epoch 200 : test accuracy [0.2025     0.79958333 0.82666667 0.83333333 0.835      0.83541667\n",
      " 0.83583333 0.83625    0.83541667 0.83541667 0.83541667]\n",
      "epoch 250 : train accuracy [0.22875    0.86333333 0.8775     0.88041667 0.88291667 0.88208333]\n",
      "epoch 300 : train accuracy [0.17458333 0.83625    0.84166667 0.84416667 0.84541667 0.84666667]\n",
      "epoch 350 : train accuracy [0.19125    0.86291667 0.87583333 0.8775     0.87791667 0.87875   ]\n",
      "epoch 400 : train accuracy [0.15208333 0.87333333 0.88916667 0.89083333 0.89333333 0.89625   ]\n",
      "epoch 400 : test accuracy [0.15208333 0.87458333 0.89541667 0.89833333 0.89958333 0.90083333\n",
      " 0.90166667 0.90208333 0.90291667 0.90333333 0.90291667]\n",
      "epoch 450 : train accuracy [0.185      0.89083333 0.89416667 0.89666667 0.89791667 0.8975    ]\n",
      "epoch 500 : train accuracy [0.17791667 0.87583333 0.88791667 0.88916667 0.89041667 0.89125   ]\n",
      "epoch 550 : train accuracy [0.20333333 0.89791667 0.91041667 0.91166667 0.91416667 0.91541667]\n",
      "epoch 600 : train accuracy [0.18625    0.90375    0.91958333 0.92       0.92166667 0.9225    ]\n",
      "epoch 600 : test accuracy [0.18583333 0.90375    0.92083333 0.92125    0.92291667 0.92458333\n",
      " 0.92458333 0.92458333 0.925      0.92416667 0.92416667]\n",
      "epoch 650 : train accuracy [0.2375     0.91666667 0.93625    0.93666667 0.93583333 0.93625   ]\n",
      "epoch 700 : train accuracy [0.16       0.89625    0.91125    0.91375    0.91666667 0.9175    ]\n",
      "epoch 750 : train accuracy [0.22541667 0.90041667 0.90458333 0.91       0.91       0.91083333]\n",
      "epoch 800 : train accuracy [0.20666667 0.91208333 0.925      0.92833333 0.92916667 0.93166667]\n",
      "epoch 800 : test accuracy [0.20541667 0.91166667 0.92625    0.92958333 0.93041667 0.93166667\n",
      " 0.93291667 0.93333333 0.93375    0.93375    0.935     ]\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(train_epoch_count):\n",
    "\n",
    "    # Train: one meta-update on a fresh batch of training tasks.\n",
    "    train_tasks = get_random_tasks(raw_data['train'], train_batch_size,\n",
    "            way_count, shot_count, query_count)\n",
    "    train_accs, _ = run(net, train_tasks, step_count=train_step_count,\n",
    "            criterion=criterion, meta_optimizer=meta_optimizer)\n",
    "    if epoch % 50 == 0: # report the training accuracy every 50 epochs\n",
    "        print('epoch {} : train accuracy {}'.format(epoch, train_accs))\n",
    "\n",
    "    # Test: adapt to tasks sampled from the test characters, without any\n",
    "    # meta-update.\n",
    "    if epoch % test_epoch_interval == 0:\n",
    "        test_tasks = get_random_tasks(raw_data['test'], test_batch_size,\n",
    "                way_count, shot_count, query_count)\n",
    "        # Evaluate on test_tasks; passing train_tasks here would report the\n",
    "        # accuracy on training characters as test accuracy.\n",
    "        test_accs, _ = run(net, test_tasks, step_count=test_step_count,\n",
    "                criterion=criterion)\n",
    "        print('epoch {} : test accuracy {}'.format(epoch, test_accs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "经过训练，测试集上的准确率由 40% 左右提升到超过 90%。\n",
    "\n",
    "After training, the accuracy on the test tasks increases from ~40% to over 90%."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
